//===- TransformTypes.cpp - Transform Dialect Type Definitions ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Compiler.h"

using namespace mlir;

#include "mlir/Dialect/Transform/IR/TransformTypeInterfaces.cpp.inc"

// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
// They are forward-declared here with LLVM_ATTRIBUTE_UNUSED so that their
// definitions (emitted into the GET_TYPEDEF_CLASSES section of the
// TransformTypes.cpp.inc included below) do not trigger unused-function
// warnings.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"

/// Registers with the Transform dialect every type listed in the ODS-generated
/// GET_TYPEDEF_LIST section. The list is expanded from the .inc file directly
/// into the template argument list of `addTypesChecked`, so adding a new
/// typedef in ODS requires no change here.
void transform::TransformDialect::initializeTypes() {
  addTypesChecked<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// transform::AnyOpType
//===----------------------------------------------------------------------===//

/// AnyOpType places no constraint on its payload: any list of payload
/// operations (including an empty one) is accepted, so this always succeeds.
DiagnosedSilenceableFailure
transform::AnyOpType::checkPayload(Location loc,
                                   ArrayRef<Operation *> payload) const {
  // Neither the location nor the payload needs to be inspected.
  (void)loc;
  (void)payload;
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// transform::OperationType
//===----------------------------------------------------------------------===//

/// Checks that every payload operation carries exactly the operation name this
/// type is parameterized with (`getOperationName()`).
///
/// Returns a silenceable error reported at `loc` for the first operation whose
/// name differs, with a note attached at that operation's location. Succeeds
/// when every operation matches, which includes the empty payload.
DiagnosedSilenceableFailure
transform::OperationType::checkPayload(Location loc,
                                       ArrayRef<Operation *> payload) const {
  // Resolve the expected name once, then compare each payload op against it.
  OperationName expectedName(getOperationName(), loc.getContext());
  for (Operation *payloadOp : payload) {
    if (payloadOp->getName() == expectedName)
      continue;

    DiagnosedSilenceableFailure mismatch =
        emitSilenceableError(loc) << "incompatible payload operation name";
    mismatch.attachNote(payloadOp->getLoc()) << "payload operation";
    return mismatch;
  }
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// transform::ParamType
//===----------------------------------------------------------------------===//

/// Verifies the parameter of a ParamType: it must be an IntegerType no wider
/// than 64 bits. Any other type (including non-integer types) is rejected via
/// `emitError`.
LogicalResult
transform::ParamType::verify(function_ref<InFlightDiagnostic()> emitError,
                             Type type) {
  // Width is capped at 64 bits (presumably so parameter values fit a host
  // 64-bit integer — not established in this file).
  auto integerType = type.dyn_cast<IntegerType>();
  bool isSupported = integerType && integerType.getWidth() <= 64;
  if (isSupported)
    return success();
  return emitError() << "only supports integer types with width <=64";
}

/// Checks that every payload attribute is an IntegerAttr whose type is exactly
/// this ParamType's parameter type (`getType()`).
///
/// Returns a silenceable error reported at `loc` describing the first
/// attribute that is either not an IntegerAttr or has a mismatching type.
/// Succeeds when all attributes pass, which includes the empty payload.
DiagnosedSilenceableFailure
transform::ParamType::checkPayload(Location loc,
                                   ArrayRef<Attribute> payload) const {
  Type expectedType = getType();
  for (Attribute attr : payload) {
    IntegerAttr integerAttr = attr.dyn_cast<IntegerAttr>();
    if (!integerAttr)
      return emitSilenceableError(loc)
             << "expected parameter to be an integer attribute, got " << attr;

    Type actualType = integerAttr.getType();
    if (actualType == expectedType)
      continue;

    return emitSilenceableError(loc)
           << "expected the type of the parameter attribute (" << actualType
           << ") to match the parameter type (" << expectedType << ")";
  }
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// transform::AnyValueType
//===----------------------------------------------------------------------===//

/// AnyValueType places no constraint on its payload: any list of payload
/// values (including an empty one) is accepted, so this always succeeds.
DiagnosedSilenceableFailure
transform::AnyValueType::checkPayload(Location loc,
                                      ArrayRef<Value> payload) const {
  // Neither the location nor the payload needs to be inspected.
  (void)loc;
  (void)payload;
  return DiagnosedSilenceableFailure::success();
}
